{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/federated/experimental/tutorials/jax_support\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/experimental/tutorials/jax_support.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/master/docs/experimental/tutorials/jax_support.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/experimental/tutorials/jax_support.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ANjgx215o-Ik"
      },
      "source": [
        "# Experimental support for JAX in TFF\n",
        "\n",
        "In addition to being a part of the TensorFlow ecosystem, TFF aims to enable\n",
        "interoperability with other frontend and backend ML frameworks. At the moment,\n",
        "support for other ML frameworks is still in the incubation phase, and the APIs\n",
        "and the functionality supported may change (largely as a function of demand from\n",
        "the users of TFF). This tutorial describes how to use TFF with JAX as an\n",
        "alternative ML frontend, and the XLA compiler as an alternative backend. The\n",
        "examples shown here are based on an entirely native JAX/XLA stack, end-to-end.\n",
        "The possibility of mixing code across frameworks (e.g., JAX with TensorFlow)\n",
        "will be discussed in one of the future tutorials.\n",
        "\n",
        "As always, we welcome your contributions. If support for JAX/XLA or the ability\n",
        "to interoperate with other ML frameworks is important for you, please consider\n",
        "helping us evolve these capabilities towards parity with the remainder of TFF.\n",
        "\n",
        "## Before we begin\n",
        "\n",
        "Please consult the main body of TFF documentation for how to configure your\n",
        "environment. Depending on where you are running this tutorial, you may want to\n",
        "uncomment and run some or all of the code below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WuK9Wi9hT4Ch"
      },
      "outputs": [],
      "source": [
        "# !pip install --quiet --upgrade tensorflow-federated-nightly\n",
        "# !pip install --quiet --upgrade nest-asyncio\n",
        "# import nest_asyncio\n",
        "# nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G0mcgo-lm6q9"
      },
      "source": [
        "This tutorial also assumes you have reviewed TFF's primary TensorFlow\n",
        "tutorials, and that you are familiar with the core TFF concepts. If you have\n",
        "not done this yet, please consider reviewing at least one of them."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZgqoOutssD-e"
      },
      "source": [
        "## JAX computations\n",
        "\n",
        "Support for JAX in TFF is designed to be symmetric with the manner in which TFF\n",
        "interoperates with TensorFlow, starting with imports:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NacZ6Aw6lZ_v"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import numpy as np\n",
        "import tensorflow_federated as tff"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1yQB7evJlI_Z"
      },
      "source": [
        "Also, just like with TensorFlow, the foundation for expressing any TFF code is\n",
        "the logic that runs locally. You can express this logic in JAX, as shown below,\n",
        "using the `@tff.experimental.jax_computation` wrapper. It behaves similarly to\n",
        "the `@tff.tf_computation` wrapper that you are familiar with by now. Let's start\n",
        "with something simple, e.g., a computation that adds two integers:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mu3ErwoOmqIG"
      },
      "outputs": [],
      "source": [
        "@tff.experimental.jax_computation(np.int32, np.int32)\n",
        "def add_numbers(x, y):\n",
        "  return jax.numpy.add(x, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A-YIXXa4nkow"
      },
      "source": [
        "You can use the JAX computation defined above just like you would normally use\n",
        "a TFF computation. For example, you can check its type signature, as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 31,
          "status": "ok",
          "timestamp": 1614320155185
        },
        "id": "qTpGm32Onj-U",
        "outputId": "41b803cb-8a5d-436b-8f2e-595f89f807fb"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'(\u003cx=int32,y=int32\u003e -\u003e int32)'"
            ]
          },
          "execution_count": 4,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(add_numbers.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2J8Pphy5rAJR"
      },
      "source": [
        "Note that we used `np.int32` to define the type of arguments. TFF does not\n",
        "distinguish between NumPy types (such as `np.int32`) and TensorFlow types\n",
        "(such as `tf.int32`). From TFF's perspective, they're just ways to refer to\n",
        "the same thing."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gff68zawrdB2"
      },
      "source": [
        "Now, remember that TFF is not Python (and if this doesn't ring a bell, please\n",
        "review some of our earlier tutorials, e.g., on custom algorithms). You can\n",
        "use the `@tff.experimental.jax_computation` wrapper with any JAX code that can\n",
        "be traced and serialized, i.e., with code that you would normally annotate\n",
        "with `@jax.jit` and expect to be compiled into XLA (but you don't need to\n",
        "actually use the `@jax.jit` annotation to embed your JAX code in TFF)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w--t1lDesKNg"
      },
      "source": [
        "Indeed, under the hood, TFF instantly compiles JAX computations to\n",
        "XLA. You can check this for yourself by manually extracting and\n",
        "printing the serialized XLA code from `add_numbers`, as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 37,
          "status": "ok",
          "timestamp": 1614320155293
        },
        "id": "HCJlOjN3qumu",
        "outputId": "377be7b1-3024-4da0-877d-bfd35c5da2da"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'xla'"
            ]
          },
          "execution_count": 5,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "comp_pb = tff.framework.serialize_computation(add_numbers)\n",
        "comp_pb.WhichOneof('computation')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 21,
          "status": "ok",
          "timestamp": 1614320155367
        },
        "id": "dBWGR0_gs8JJ",
        "outputId": "75c6f689-178e-4df9-fc3c-2724f7fffff6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "HloModule xla_computation_add_numbers.7\n",
            "\n",
            "ENTRY xla_computation_add_numbers.7 {\n",
            "  constant.4 = pred[] constant(false)\n",
            "  parameter.1 = (s32[], s32[]) parameter(0)\n",
            "  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0\n",
            "  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1\n",
            "  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)\n",
            "  ROOT tuple.6 = (s32[]) tuple(add.5)\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)\n",
        "print(xla_code.as_hlo_text())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QMOV9sbWtzNB"
      },
      "source": [
        "Think of the XLA representation of a JAX computation as the functional\n",
        "equivalent of `tf.GraphDef` for computations expressed in TensorFlow. It is\n",
        "portable and executable in a variety of environments that support XLA, just like\n",
        "a `tf.GraphDef` can be executed on any TensorFlow runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IPN_fjW4uEjg"
      },
      "source": [
        "TFF provides a runtime stack based on the XLA compiler as a backend. You can\n",
        "activate it as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DoIkNfhftp2r"
      },
      "outputs": [],
      "source": [
        "tff.backends.xla.set_local_execution_context()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OrELxlxGucY0"
      },
      "source": [
        "Now, you can execute the computation we defined above:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 68,
          "status": "ok",
          "timestamp": 1614320155486
        },
        "id": "ZLfwnzh5ubeA",
        "outputId": "8e92160e-c88f-4347-e340-f87e73bc7740"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "5"
            ]
          },
          "execution_count": 8,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "add_numbers(2, 3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Flmb2lyGuhM6"
      },
      "source": [
        "Easy enough. Let's keep going and do something more complicated, such as\n",
        "MNIST."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yxbDbwa-w0rj"
      },
      "source": [
        "## Example of MNIST training with canned API\n",
        "\n",
        "As usual, we start by defining a bunch of TFF types for batches of data,\n",
        "and for the model (remember, TFF is a strongly-typed framework)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AyAAQEtxu2Jg"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "\n",
        "BATCH_TYPE = collections.OrderedDict([\n",
        "    ('pixels', tff.TensorType(np.float32, (50, 784))),\n",
        "    ('labels', tff.TensorType(np.int32, (50,)))\n",
        "])\n",
        "\n",
        "MODEL_TYPE = collections.OrderedDict([\n",
        "    ('weights', tff.TensorType(np.float32, (784, 10))),\n",
        "    ('bias', tff.TensorType(np.float32, (10,)))\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J4rCcC4lvBKm"
      },
      "source": [
        "Now, let's define a loss function for the model in JAX, taking the model and a\n",
        "single batch of data as parameters:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l7Y-PyVJvGlB"
      },
      "outputs": [],
      "source": [
        "def loss(model, batch):\n",
        "  y = jax.nn.softmax(\n",
        "      jax.numpy.add(\n",
        "          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))\n",
        "  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)\n",
        "  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BGnwZTjUxJi4"
      },
      "source": [
        "Now, one way to go is to use a canned API. Here's an example of how you can use\n",
        "our API to create a training process based on the loss function just defined."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pSoB_jW3xfRd"
      },
      "outputs": [],
      "source": [
        "STEP_SIZE = 0.001\n",
        "\n",
        "trainer = tff.experimental.learning.build_jax_federated_averaging_process(\n",
        "    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g8TS_LDnxktY"
      },
      "source": [
        "You can use the above just as you would use a trainer built from a `tf.keras`\n",
        "model in TensorFlow. For example, here's how you can create the initial model\n",
        "for training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 43,
          "status": "ok",
          "timestamp": 1614320155674
        },
        "id": "roJas9RGx9Sv",
        "outputId": "2ea5b9ce-77b7-4168-89da-69f2553c098f"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],\n",
              "       [0., 0., 0., ..., 0., 0., 0.],\n",
              "       [0., 0., 0., ..., 0., 0., 0.],\n",
              "       ...,\n",
              "       [0., 0., 0., ..., 0., 0., 0.],\n",
              "       [0., 0., 0., ..., 0., 0., 0.],\n",
              "       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])"
            ]
          },
          "execution_count": 12,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "initial_model = trainer.initialize()\n",
        "initial_model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j9Ii7c24yDcX"
      },
      "source": [
        "In order to perform actual training, we need some data. Let's make random data\n",
        "to keep it simple. Because the data is random, we will evaluate on the training\n",
        "data; with random eval data, we could hardly expect the model to perform well.\n",
        "Also, for this small-scale demo, we will not worry about randomly sampling\n",
        "clients (we leave it as an exercise to the user to explore those types of\n",
        "changes by following the templates from other tutorials):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z7hnu5QUwYOG"
      },
      "outputs": [],
      "source": [
        "def random_batch():\n",
        "  pixels = np.random.uniform(\n",
        "      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)\n",
        "  # `high` is exclusive, so high=10 draws labels from all 10 classes (0-9).\n",
        "  labels = np.random.randint(low=0, high=10, size=(50,), dtype=np.int32)\n",
        "  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])\n",
        "\n",
        "NUM_CLIENTS = 2\n",
        "NUM_BATCHES = 10\n",
        "\n",
        "train_data = [\n",
        "    [random_batch() for _ in range(NUM_BATCHES)]\n",
        "    for _ in range(NUM_CLIENTS)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tET5dIv3yoPR"
      },
      "source": [
        "With that, we can perform a single step of training, as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 4797,
          "status": "ok",
          "timestamp": 1614320160529
        },
        "id": "c2BDLhjdyrtX",
        "outputId": "804ffbd6-3ecb-412e-9802-82264073b05c"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,\n",
              "         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],\n",
              "       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,\n",
              "         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],\n",
              "       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,\n",
              "         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],\n",
              "       ...,\n",
              "       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,\n",
              "         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],\n",
              "       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,\n",
              "         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],\n",
              "       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,\n",
              "         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,\n",
              "        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,\n",
              "        7.7109929e-05, -9.1987278e-04], dtype=float32))])"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trained_model = trainer.next(initial_model, train_data)\n",
        "trained_model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QCdbr2ivy2oW"
      },
      "source": [
        "Let's evaluate the result of the training step. To keep it easy, we can evaluate\n",
        "it in a centralized fashion:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 670,
          "status": "ok",
          "timestamp": 1614320161206
        },
        "id": "EgzZd7GEzMyi",
        "outputId": "055ad19b-32de-476e-a5bb-c33ff8b53d09"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.3025854\n",
            "2.282762\n"
          ]
        }
      ],
      "source": [
        "import itertools\n",
        "eval_data = list(itertools.chain.from_iterable(train_data))\n",
        "\n",
        "def average_loss(model, data):\n",
        "  return np.mean([loss(model, batch) for batch in data])\n",
        "\n",
        "print(average_loss(initial_model, eval_data))\n",
        "print(average_loss(trained_model, eval_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PhsYzZsM05Jd"
      },
      "source": [
        "The loss is decreasing. Great! Now, let's run this over multiple rounds:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 92976,
          "status": "ok",
          "timestamp": 1614320254189
        },
        "id": "nB1hZJyR1AN9",
        "outputId": "2e6ac884-1ebf-47b8-dbf9-e461a74c2305"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.2685437\n",
            "2.257856\n",
            "2.2495182\n",
            "2.2428129\n",
            "2.2372835\n",
            "2.2326245\n",
            "2.2286277\n",
            "2.2251441\n",
            "2.2220676\n",
            "2.219318\n",
            "2.2168345\n",
            "2.2145717\n",
            "2.2124937\n",
            "2.2105706\n",
            "2.2087805\n",
            "2.2071042\n",
            "2.2055268\n",
            "2.2040353\n",
            "2.2026198\n",
            "2.2012706\n"
          ]
        }
      ],
      "source": [
        "NUM_ROUNDS = 20\n",
        "for _ in range(NUM_ROUNDS):\n",
        "  trained_model = trainer.next(trained_model, train_data)\n",
        "  print(average_loss(trained_model, eval_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jBC-Mp930_Yn"
      },
      "source": [
        "As you can see, using JAX with TFF is not that different from using TensorFlow,\n",
        "although the experimental APIs are not yet on par with the TensorFlow APIs\n",
        "functionality-wise."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PFg1CaGf1vqe"
      },
      "source": [
        "## Under the hood\n",
        "\n",
        "If you prefer not to use our canned API, you can implement your own custom\n",
        "computations, much as you have seen it done in the custom algorithms tutorials\n",
        "for TensorFlow, except that you will use JAX's mechanism for computing\n",
        "gradients. For example, below is how you can define a JAX computation that\n",
        "updates the model on a single minibatch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hw-p7DAN2G7B"
      },
      "outputs": [],
      "source": [
        "@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)\n",
        "def train_on_one_batch(model, batch):\n",
        "  # jax.grad differentiates `loss` with respect to its first argument (the model).\n",
        "  grads = jax.grad(loss)(model, batch)\n",
        "  return collections.OrderedDict([\n",
        "      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']\n",
        "  ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a0pDsem22R6x"
      },
      "source": [
        "Here's how you can test that it works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 479,
          "status": "ok",
          "timestamp": 1614320255326
        },
        "id": "LBB84zAO2Y4b",
        "outputId": "33f2abef-3cb8-41e3-e26d-33c1fc285ecc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.3025854\n",
            "2.2977567\n"
          ]
        }
      ],
      "source": [
        "sample_batch = random_batch()\n",
        "trained_model = train_on_one_batch(initial_model, sample_batch)\n",
        "print(average_loss(initial_model, [sample_batch]))\n",
        "print(average_loss(trained_model, [sample_batch]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7SE99B3L3McK"
      },
      "source": [
        "One caveat of working with JAX is that it does not offer the equivalent of\n",
        "`tf.data.Dataset`. Thus, in order to iterate over datasets, you will need to\n",
        "use TFF's declarative constructs for operations on sequences, such as the one\n",
        "shown below:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QT3rqv4x3ipO"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))\n",
        "def train_on_one_client(model, batches):\n",
        "  return tff.sequence_reduce(batches, model, train_on_one_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "csx5qskW3quD"
      },
      "source": [
        "Let's see that it works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 43411,
          "status": "ok",
          "timestamp": 1614320298776
        },
        "id": "y9-NLvAM3sE_",
        "outputId": "494fef0f-3972-43c8-def7-40dc2869343f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.3025854\n",
            "2.2284968\n"
          ]
        }
      ],
      "source": [
        "sample_dataset = [random_batch() for _ in range(100)]\n",
        "trained_model = train_on_one_client(initial_model, sample_dataset)\n",
        "print(average_loss(initial_model, sample_dataset))\n",
        "print(average_loss(trained_model, sample_dataset))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2D0ApaBl4JMf"
      },
      "source": [
        "The computation that performs a single round of training looks just like the\n",
        "one you may have seen in the TensorFlow tutorials:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RVqbU1p-9i2j"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(\n",
        "    tff.FederatedType(MODEL_TYPE, tff.SERVER),\n",
        "    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))\n",
        "def train_one_round(model, federated_data):\n",
        "  locally_trained_models = tff.federated_map(\n",
        "      train_on_one_client,\n",
        "      collections.OrderedDict([\n",
        "          ('model', tff.federated_broadcast(model)),\n",
        "          ('batches', federated_data)]))\n",
        "  return tff.federated_mean(locally_trained_models)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VHzYmP6K-HBs"
      },
      "source": [
        "Let's see that it works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 4775,
          "status": "ok",
          "timestamp": 1614320303608
        },
        "id": "1XFsQxT2-GVT",
        "outputId": "35c4517e-5f10-40ec-f0da-3c89b9cb4d18"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.3025854\n",
            "2.282762\n"
          ]
        }
      ],
      "source": [
        "trained_model = train_one_round(initial_model, train_data)\n",
        "print(average_loss(initial_model, eval_data))\n",
        "print(average_loss(trained_model, eval_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oDiW1NgL9iT7"
      },
      "source": [
        "As you can see, using JAX in TFF, whether via canned APIs or directly using the\n",
        "low-level TFF constructs, is similar to using TFF with TensorFlow.\n",
        "Stay tuned for future updates, and if you'd like to see better support for\n",
        "interoperability across ML frameworks, feel free to send us a pull request!"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Experimental support for JAX in TFF",
      "provenance": [
        {
          "file_id": "1jn7qngZ3A3GVP4xqWAk6zRtsXLW1VBIE",
          "timestamp": 1614292188153
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
